{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "view-in-github"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_6_CNN.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "uSGZ6cdmpknm",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/snntorch_alpha_w.png?raw=true' width=\"400\">](https://github.com/jeshraghian/snntorch/)\n",
        "\n",
        "# snnTorch - Surrogate Gradient Descent in a Convolutional Spiking Neural Network\n",
        "## Tutorial 6\n",
        "### By Jason K. Eshraghian (www.jasoneshraghian.com)\n",
        "\n",
        "<a href=\"https://colab.research.google.com/github/jeshraghian/snntorch/blob/master/examples/tutorial_6_CNN.ipynb\">\n",
        "  <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
        "</a>\n",
        "\n",
        "[<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub-Mark-Light-120px-plus.png?raw=true' width=\"28\">](https://github.com/jeshraghian/snntorch/) [<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/GitHub_Logo_White.png?raw=true' width=\"80\">](https://github.com/jeshraghian/snntorch/)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rugeYYiqsrlc"
      },
      "source": [
        "The snnTorch tutorial series is based on the following paper. If you find these resources or code useful in your work, please consider citing the following source:\n",
        "\n",
        "> <cite> [Jason K. Eshraghian, Max Ward, Emre Neftci, Xinxin Wang, Gregor Lenz, Girish Dwivedi, Mohammed Bennamoun, Doo Seok Jeong, and Wei D. Lu. \"Training Spiking Neural Networks Using Lessons From Deep Learning\". arXiv preprint arXiv:2109.12894, September 2021.](https://arxiv.org/abs/2109.12894) </cite>\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ymi3sqJg28OQ"
      },
      "source": [
        "# Introduction\n",
        "In this tutorial, you will:\n",
        "* Learn how to use surrogate gradient descent to overcome the dead neuron problem\n",
        "* Construct and train a convolutional spiking neural network\n",
        "* Use a sequential container, `nn.Sequential` to simplify model construction\n",
        "* Use the `snn.backprop` module to reduce the time it takes to design a neural network\n",
        "\n",
        "Part of this tutorial was inspired by Friedemann Zenke’s extensive\n",
        "work on SNNs. Check out his repo on surrogate gradients\n",
        "[here](https://github.com/fzenke/spytorch), and a favourite paper\n",
        "of mine: E. O. Neftci, H. Mostafa, F. Zenke, [Surrogate Gradient\n",
        "Learning in Spiking Neural Networks: Bringing the Power of\n",
        "Gradient-based optimization to spiking neural\n",
        "networks.](https://ieeexplore.ieee.org/document/8891809) IEEE\n",
        "Signal Processing Magazine 36, 51–63.\n",
        "\n",
        "At the end of the tutorial, we will train a convolutional spiking neural network (CSNN) using the MNIST dataset to perform image classification. The background theory follows on from [Tutorials 2, 4 and 5](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), so feel free to go back if you need to brush up.\n",
        "\n",
        "If running in Google Colab:\n",
        "* You may connect to GPU by checking `Runtime` > `Change runtime type` > `Hardware accelerator: GPU`\n",
        "* Next, install the latest PyPi distribution of snnTorch by clicking into the following cell and pressing `Shift+Enter`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5tn_wUlopkon",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "!pip install snntorch"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "QXZ6Tuqc9Q-l"
      },
      "outputs": [],
      "source": [
        "# imports\n",
        "import snntorch as snn\n",
        "from snntorch import surrogate\n",
        "from snntorch import backprop\n",
        "from snntorch import functional as SF\n",
        "from snntorch import utils\n",
        "from snntorch import spikeplot as splt\n",
        "\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "from torch.utils.data import DataLoader\n",
        "from torchvision import datasets, transforms\n",
        "import torch.nn.functional as F\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import itertools"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gt2xMbLY9dVE"
      },
      "source": [
        "# 1. Surrogate Gradient Descent"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zJddJWoa0GT6"
      },
      "source": [
        "[Tutorial 5](https://snntorch.readthedocs.io/en/latest/tutorials/index.html) raised the **dead neuron problem**. This arises because of the non-differentiability of spikes:\n",
        "\n",
        "\n",
        "$$S[t] = \\Theta(U[t] - U_{\\rm thr}) \\tag{1}$$ \n",
        "$$\\frac{\\partial S}{\\partial U} = \\delta(U - U_{\\rm thr}) \\in \\{0, \\infty\\} \\tag{2}$$ \n",
        "\n",
        "where $\\Theta(\\cdot)$ is the Heaviside step function, and $\\delta(\\cdot)$ is the Dirac-Delta function. We previously overcame this using the *Spike-Operator* approach, by assigning the spike to the derivative term: $\\partial \\tilde{S}/\\partial U \\leftarrow S \\in \\{0, 1\\}$. Another approach is to smooth the Heaviside function during the backward pass, which correspondingly smooths out the gradient of the Heaviside function. \n",
        "\n",
        "Common smoothing functions include the sigmoid function, or the fast sigmoid function. The sigmoidal functions must also be shifted such that they are centered at the threshold $U_{\\rm thr}$. Defining  the overdrive of the membrane potential as $U_{OD} = U - U_{\\rm thr}$:\n",
        "\n",
        "$$\\tilde{S} = \\frac{U_{OD}}{1+k|U_{OD}|} \\tag{3}$$\n",
        "$$\\frac{\\partial \\tilde{S}}{\\partial U} = \\frac{1}{(k|U_{OD}|+1)^2}\\tag{4}$$\n",
        "\n",
        "where $k$ modulates how smooth the surrogate function is, and is treated as a hyperparameter. As $k$ increases, the approximation converges towards the original derivative in $(2)$:  \n",
        "\n",
        "$$\\frac{\\partial \\tilde{S}}{\\partial U} \\Bigg|_{k \\rightarrow \\infty} = \\delta(U-U_{\\rm thr})$$ \n",
        "\n",
        "\n",
        "<center>\n",
        "<img src='https://github.com/jeshraghian/snntorch/blob/master/docs/_static/img/examples/tutorial6/surrogate.png?raw=true' width=\"800\">\n",
        "</center>\n",
        "\n",
        "To summarize:\n",
        "\n",
        "* **Forward Pass**\n",
        "  - Determine $S$ using the shifted Heaviside function in $(1)$\n",
        "  - Store $U$ for later use during the backward pass\n",
        "* **Backward Pass**\n",
        "  - Pass $U$ into $(4)$ to calculate the derivative term\n",
        "\n",
        "In the same way the *Spike Operator* approach was used in [Tutorial 5](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), \n",
        "the gradient of the fast sigmoid function can override the Dirac-Delta function in a Leaky Integrate-and-Fire\n",
        "(LIF) neuron model:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5lqpuT--bZmJ"
      },
      "outputs": [],
      "source": [
        "# Leaky neuron model, overriding the backward pass with a custom function\n",
        "class LeakySigmoidSurrogate(nn.Module):\n",
        "  def __init__(self, beta, threshold=1.0, k=25):\n",
        "      super(Leaky_Surrogate, self).__init__()\n",
        "\n",
        "      # initialize decay rate beta and threshold\n",
        "      self.beta = beta\n",
        "      self.threshold = threshold\n",
        "      self.surrogate_func = self.FastSigmoid.apply\n",
        "  \n",
        "  # the forward function is called each time we call Leaky\n",
        "  def forward(self, input_, mem):\n",
        "    spk = self.surrogate_func((mem-self.threshold))  # call the Heaviside function\n",
        "    reset = (spk - self.threshold).detach()\n",
        "    mem = self.beta * mem + input_ - reset\n",
        "    return spk, mem\n",
        "\n",
        "  # Forward pass: Heaviside function\n",
        "  # Backward pass: Override Dirac Delta with gradient of fast sigmoid\n",
        "  @staticmethod\n",
        "  class FastSigmoid(torch.autograd.Function):  \n",
        "    @staticmethod\n",
        "    def forward(ctx, mem, k=25):\n",
        "        ctx.save_for_backward(mem) # store the membrane potential for use in the backward pass\n",
        "        ctx.k = k\n",
        "        out = (mem > 0).float() # Heaviside on the forward pass: Eq(1)\n",
        "        return out\n",
        "\n",
        "    @staticmethod\n",
        "    def backward(ctx, grad_output): \n",
        "        (mem,) = ctx.saved_tensors  # retrieve membrane potential\n",
        "        grad_input = grad_output.clone()\n",
        "        grad = grad_input / (ctx.k * torch.abs(mem) + 1.0) ** 2  # gradient of fast sigmoid on backward pass: Eq(4)\n",
        "        return grad, None"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4aQvFG7ajpzU"
      },
      "source": [
        "Better yet, all of that can be condensed by using the built-in module `snn.surrogate` from snnTorch, where $k$ from $(4)$ is denoted `slope`. The surrogate gradient is passed into `spike_grad` as an argument:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2dCWD_qajyLw"
      },
      "outputs": [],
      "source": [
        "spike_grad = surrogate.fast_sigmoid(slope=25)\n",
        "beta = 0.5\n",
        "\n",
        "lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ffn7D6omkj5r"
      },
      "source": [
        "To explore the other surrogate gradient functions available, [take a look at the documentation here.](https://snntorch.readthedocs.io/en/latest/snntorch.surrogate.html)"
      ]
    },
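    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a quick sanity check of Eqs. $(3)$ and $(4)$, the following sketch evaluates the fast sigmoid gradient in plain Python and compares it against a central finite-difference estimate of Eq. $(3)$. The helper names `fast_sigmoid` and `fast_sigmoid_grad` and the `demo_*` variables are made up for this illustration and are not part of snnTorch. It also prints how the gradient at a fixed distance above the threshold shrinks as $k$ grows, while the gradient exactly at the threshold stays at 1:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative sketch: these helpers are hypothetical and not part of snnTorch\n",
        "def fast_sigmoid(u_od, k):\n",
        "    \"\"\"Eq. (3): fast sigmoid of the overdrive U_OD = U - U_thr.\"\"\"\n",
        "    return u_od / (1.0 + k * abs(u_od))\n",
        "\n",
        "def fast_sigmoid_grad(u_od, k):\n",
        "    \"\"\"Eq. (4): derivative of the fast sigmoid with respect to U.\"\"\"\n",
        "    return 1.0 / (k * abs(u_od) + 1.0) ** 2\n",
        "\n",
        "demo_u_od = 0.1  # membrane potential 0.1 above threshold (arbitrary choice)\n",
        "demo_h = 1e-6    # step size for the central finite difference\n",
        "\n",
        "for demo_k in (1, 25, 100):\n",
        "    demo_analytic = fast_sigmoid_grad(demo_u_od, demo_k)\n",
        "    demo_numeric = (fast_sigmoid(demo_u_od + demo_h, demo_k)\n",
        "                    - fast_sigmoid(demo_u_od - demo_h, demo_k)) / (2 * demo_h)\n",
        "    print(f\"k={demo_k:>3}: analytic={demo_analytic:.6f}, finite difference={demo_numeric:.6f}\")\n",
        "\n",
        "# the gradient at the threshold itself does not depend on k\n",
        "print(f\"gradient at threshold: {fast_sigmoid_grad(0.0, 25):.1f}\")"
      ]
    },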
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Wgzf83HE2BeB"
      },
      "source": [
        "# 2. Setting up the CSNN\n",
        "## 2.1 DataLoaders\n",
        "Note that `utils.data_subset()` is called to reduce the size of the dataset by a factor of 10 to speed up training."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pxO32dntlOB2"
      },
      "outputs": [],
      "source": [
        "# dataloader arguments\n",
        "batch_size = 128\n",
        "data_path='/data/mnist'\n",
        "subset=10\n",
        "\n",
        "dtype = torch.float\n",
        "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pE7eGTnulSBA"
      },
      "outputs": [],
      "source": [
        "# Define a transform\n",
        "transform = transforms.Compose([\n",
        "            transforms.Resize((28, 28)),\n",
        "            transforms.Grayscale(),\n",
        "            transforms.ToTensor(),\n",
        "            transforms.Normalize((0,), (1,))])\n",
        "\n",
        "mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)\n",
        "mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)\n",
        "\n",
        "# reduce datasets by 10x to speed up training\n",
        "utils.data_subset(mnist_train, subset)\n",
        "utils.data_subset(mnist_test, subset)\n",
        "\n",
        "# Create DataLoaders\n",
        "train_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)\n",
        "test_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6v8fBXrVlY3f"
      },
      "source": [
        "## 2.2 Define the Network\n",
        "\n",
        "The convolutional network architecture to be used is: 12C5-MP2-64C5-MP2-1024FC10\n",
        "\n",
        "- 12C5 is a 5$\\times$5 convolutional kernel with 12 filters\n",
        "- MP2 is a 2$\\times$2 max-pooling function\n",
        "- 1024FC10 is a fully-connected layer that maps 1,024 neurons to 10 outputs"
      ]
    },
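    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The 1,024 inputs to the fully-connected layer follow from the feature-map sizes. The sketch below traces them in plain Python, assuming 28$\\times$28 MNIST inputs, unpadded stride-1 convolutions, and non-overlapping 2$\\times$2 pooling. The helper `conv_out_size` and the `demo_*` variables are hypothetical names used only for this illustration:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Illustrative helper (not part of PyTorch or snnTorch)\n",
        "def conv_out_size(size, kernel, stride=1, padding=0):\n",
        "    \"\"\"Output side length of a square convolution or pooling window.\"\"\"\n",
        "    return (size + 2 * padding - kernel) // stride + 1\n",
        "\n",
        "demo_side = 28                                            # MNIST images are 28x28\n",
        "demo_side = conv_out_size(demo_side, kernel=5)            # 12C5: 28 -> 24\n",
        "demo_side = conv_out_size(demo_side, kernel=2, stride=2)  # MP2:  24 -> 12\n",
        "demo_side = conv_out_size(demo_side, kernel=5)            # 64C5: 12 -> 8\n",
        "demo_side = conv_out_size(demo_side, kernel=2, stride=2)  # MP2:  8 -> 4\n",
        "\n",
        "demo_fc_inputs = 64 * demo_side * demo_side               # 64 channels of 4x4 features\n",
        "print(f\"Feature map: 64 x {demo_side} x {demo_side} -> {demo_fc_inputs} inputs to the fully-connected layer\")"
      ]
    },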
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "foos_NlopDrb"
      },
      "outputs": [],
      "source": [
        "# neuron and simulation parameters\n",
        "spike_grad = surrogate.fast_sigmoid(slope=25)\n",
        "beta = 0.5\n",
        "num_steps = 50"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "X4sd8PDSlGZb"
      },
      "outputs": [],
      "source": [
        "# Define Network\n",
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        # Initialize layers\n",
        "        self.conv1 = nn.Conv2d(1, 12, 5)\n",
        "        self.lif1 = snn.Leaky(beta=beta, spike_grad=spike_grad)\n",
        "        self.conv2 = nn.Conv2d(12, 64, 5)\n",
        "        self.lif2 = snn.Leaky(beta=beta, spike_grad=spike_grad)\n",
        "        self.fc1 = nn.Linear(64*4*4, 10)\n",
        "        self.lif3 = snn.Leaky(beta=beta, spike_grad=spike_grad)\n",
        "\n",
        "    def forward(self, x):\n",
        "\n",
        "        # Initialize hidden states and outputs at t=0\n",
        "        mem1 = self.lif1.init_leaky()\n",
        "        mem2 = self.lif2.init_leaky() \n",
        "        mem3 = self.lif3.init_leaky()\n",
        "\n",
        "        # Record the final layer\n",
        "        spk3_rec = []\n",
        "        mem3_rec = []\n",
        "\n",
        "        for step in range(num_steps):\n",
        "            cur1 = F.max_pool2d(self.conv1(x), 2)\n",
        "            spk1, mem1 = self.lif1(cur1, mem1)\n",
        "            cur2 = F.max_pool2d(self.conv2(spk1), 2)\n",
        "            spk2, mem2 = self.lif2(cur2, mem2)\n",
        "            cur3 = self.fc1(spk2.view(batch_size, -1))\n",
        "            spk3, mem3 = self.lif3(cur3, mem3)\n",
        "\n",
        "            spk3_rec.append(spk3)\n",
        "            mem3_rec.append(mem3)\n",
        "\n",
        "        return torch.stack(spk3_rec), torch.stack(mem3_rec)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HVn3aYAUnWqH"
      },
      "source": [
        "In the previous tutorial, the network was wrapped inside of a class, as shown above. \n",
        "With increasing network complexity, this adds a lot of boilerplate code that we might wish to avoid. Alternatively, the `nn.Sequential` method can be used instead:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AoYBY89angvp"
      },
      "outputs": [],
      "source": [
        "#  Initialize Network\n",
        "net = nn.Sequential(nn.Conv2d(1, 12, 5),\n",
        "                    nn.MaxPool2d(2),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),\n",
        "                    nn.Conv2d(12, 64, 5),\n",
        "                    nn.MaxPool2d(2),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True),\n",
        "                    nn.Flatten(),\n",
        "                    nn.Linear(64*4*4, 10),\n",
        "                    snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True)\n",
        "                    ).to(device)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7Qgw1dRmpOlo"
      },
      "source": [
        "\n",
        "The `init_hidden` argument initializes the hidden states of the neuron (here, membrane potential). This takes place in the background as an instance variable. \n",
        "If `init_hidden` is activated, the membrane potential is not explicitly returned to the user, ensuring only the output spikes are sequentially passed through the layers wrapped in `nn.Sequential`. \n",
        "\n",
        "To train a model using the final layer's membrane potential, set the argument `output=True`. \n",
        "This enables the final layer to return both the spike and membrane potential response of the neuron."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A-bCSQmBstvd"
      },
      "source": [
        "## 2.3 Forward-Pass\n",
        "A forward pass across a simulation duration of `num_steps` looks like this:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IxPPMND-pMxE"
      },
      "outputs": [],
      "source": [
        "data, targets = next(iter(train_loader))\n",
        "data = data.to(device)\n",
        "targets = targets.to(device)\n",
        "\n",
        "for step in range(num_steps):\n",
        "    spk_out, mem_out = net(data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C3PxtCobuH_e"
      },
      "source": [
        "Wrap that in a function, recording the membrane potential and spike response over time:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ykdnD3tRuHcs"
      },
      "outputs": [],
      "source": [
        "def forward_pass(net, num_steps, data):\n",
        "  mem_rec = []\n",
        "  spk_rec = []\n",
        "  utils.reset(net)  # resets hidden states for all LIF neurons in net\n",
        "\n",
        "  for step in range(num_steps):\n",
        "      spk_out, mem_out = net(data)\n",
        "      spk_rec.append(spk_out)\n",
        "      mem_rec.append(mem_out)\n",
        "  \n",
        "  return torch.stack(spk_rec), torch.stack(mem_rec)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "unJrx3pXcXii"
      },
      "outputs": [],
      "source": [
        "spk_rec, mem_rec = forward_pass(net, num_steps, data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zqJdfllYbc16"
      },
      "source": [
        "# 3. Training Loop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x-kquOWLY1Jo"
      },
      "source": [
        "## 3.1 Loss Using snn.Functional"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MlA56BgOYq1D"
      },
      "source": [
        "In the previous tutorial, the Cross Entropy Loss between the membrane potential of the output neurons and the target was used to train the network. \n",
        "This time, the total number of spikes from each neuron will be used to calculate the Cross Entropy instead.\n",
        "\n",
        "A variety of loss functions are included in the `snn.functional` module, which is analogous to `torch.nn.functional` in PyTorch. \n",
        "These implement a mix of cross entropy and mean square error losses, are applied to spikes and/or membrane potential, to train a rate or latency-coded network. \n",
        "\n",
        "The approach below applies the cross entropy loss to the output spike count in order train a rate-coded network:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UZ2BM6d6a11l"
      },
      "outputs": [],
      "source": [
        "# already imported snntorch.functional as SF \n",
        "loss_fn = SF.ce_rate_loss()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q39HCIeOa4fC"
      },
      "source": [
        "The recordings of the spike are passed as the first argument to `loss_fn`, and the target neuron index as the second argument to generate a loss. [The documentation provides further information and exmaples.](https://snntorch.readthedocs.io/en/latest/snntorch.functional.html#snntorch.functional.ce_rate_loss)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xEVzMvujcjsE"
      },
      "outputs": [],
      "source": [
        "loss_val = loss_fn(spk_rec, targets)\n",
        "\n",
        "print(f\"The loss from an untrained network is {loss_val.item():.3f}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HS_zeb5mbqjw"
      },
      "source": [
        "## 3.2 Accuracy Using snn.Functional\n",
        "The `SF.accuracy_rate()` function works similarly, in that the predicted output spikes and actual targets are supplied as arguments. `accuracy_rate` assumes a rate code is used to interpret the output by checking if the index of the neuron with the highest spike count matches the target index."
      ]
    },
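    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the rate-code interpretation concrete, here is a minimal sketch in plain PyTorch of the idea described above: sum the spikes of each output neuron over time, then pick the neuron with the highest count. The tensors `toy_spk_rec` and `toy_targets` hold made-up values, and this only illustrates the principle; it is not necessarily how `SF.accuracy_rate` is implemented internally:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch\n",
        "\n",
        "# Toy spike recording with shape [num_steps, batch_size, num_outputs]; values are made up\n",
        "toy_spk_rec = torch.tensor([\n",
        "    [[0., 1., 0.], [1., 0., 0.]],\n",
        "    [[0., 1., 0.], [0., 0., 1.]],\n",
        "    [[1., 1., 0.], [0., 0., 1.]],\n",
        "])\n",
        "toy_targets = torch.tensor([1, 0])  # hypothetical target indices for the two samples\n",
        "\n",
        "toy_spike_count = toy_spk_rec.sum(dim=0)        # total spikes per output neuron\n",
        "toy_prediction = toy_spike_count.argmax(dim=1)  # neuron with the highest spike count\n",
        "toy_acc = (toy_prediction == toy_targets).float().mean().item()\n",
        "\n",
        "print(f\"Spike counts:\\n{toy_spike_count}\")\n",
        "print(f\"Predictions: {toy_prediction.tolist()}, accuracy: {toy_acc * 100:.1f}%\")"
      ]
    },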
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yq7_rly0c9b4"
      },
      "outputs": [],
      "source": [
        "acc = SF.accuracy_rate(spk_rec, targets)\n",
        "\n",
        "print(f\"The accuracy of a single batch using an untrained network is {acc*100:.3f}%\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r4Z6bnqCdL50"
      },
      "source": [
        "As the above function only returns the accuracy of a single batch of data, the following function returns the accuracy on the entire DataLoader object:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IqxDKFvrdXuF"
      },
      "outputs": [],
      "source": [
        "def batch_accuracy(train_loader, net, num_steps):\n",
        "  with torch.no_grad():\n",
        "    total = 0\n",
        "    acc = 0\n",
        "    net.eval()\n",
        "    \n",
        "    train_loader = iter(train_loader)\n",
        "    for data, targets in train_loader:\n",
        "      data = data.to(device)\n",
        "      targets = targets.to(device)\n",
        "      spk_rec, _ = forward_pass(net, num_steps, data)\n",
        "\n",
        "      acc += SF.accuracy_rate(spk_rec, targets) * spk_rec.size(1)\n",
        "      total += spk_rec.size(1)\n",
        "\n",
        "  return acc/total"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_u43hKAvefWM"
      },
      "outputs": [],
      "source": [
        "test_acc = batch_accuracy(test_loader, net, num_steps)\n",
        "\n",
        "print(f\"The total accuracy on the test set is: {test_acc * 100:.2f}%\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F1pzWXlsYoIu"
      },
      "source": [
        "## 3.3 Training Automation Using snn.backprop"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KAUDYl3gf0G-"
      },
      "source": [
        "Training SNNs can become arduous even with simple networks, so the `snn.backprop` module is here to reduce some of this effort.\n",
        "\n",
        "The `backprop.BPTT` function automatically performs a single epoch of training, where you need only provide the training parameters, dataloader, and several other arguments. \n",
        "The average loss across iterations is returned. \n",
        "The argument `time_var` indicates whether the input data is time-varying. \n",
        "As we are using the MNIST dataset, we explicitly specify `time_var=False`. \n",
        "\n",
        "The following code block may take a while to run. If you are not connected to GPU, then consider reducing `num_epochs`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dnS4wYyh0bdb"
      },
      "outputs": [],
      "source": [
        "optimizer = torch.optim.Adam(net.parameters(), lr=1e-2, betas=(0.9, 0.999))\n",
        "num_epochs = 10\n",
        "test_acc_hist = []\n",
        "\n",
        "# training loop\n",
        "for epoch in range(num_epochs):\n",
        "\n",
        "    avg_loss = backprop.BPTT(net, train_loader, optimizer=optimizer, criterion=loss_fn, \n",
        "                            num_steps=num_steps, time_var=False, device=device)\n",
        "    \n",
        "    print(f\"Epoch {epoch}, Train Loss: {avg_loss.item():.2f}\")\n",
        "\n",
        "    # Test set accuracy\n",
        "    test_acc = batch_accuracy(train_loader, net, num_steps)\n",
        "    test_acc_hist.append(test_acc)\n",
        "\n",
        "    print(f\"Epoch {epoch}, Test Acc: {test_acc * 100:.2f}%\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TjRPDFWxj2eS"
      },
      "source": [
        "Despite having selected some fairly generic values and architectures, the test set accuracy should be fairly competitive given the brief training run!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "collapsed": false,
        "id": "HxU7P7xFpko3",
        "pycharm": {
          "name": "#%% md\n"
        }
      },
      "source": [
        "# 4. Results\n",
        "## 4.1 Plot Test Accuracy"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Pk_EScnpkpj",
        "pycharm": {
          "name": "#%%\n"
        }
      },
      "outputs": [],
      "source": [
        "# Plot Loss\n",
        "fig = plt.figure(facecolor=\"w\")\n",
        "plt.plot(test_acc_hist)\n",
        "plt.title(\"Test Set Accuracy\")\n",
        "plt.xlabel(\"Epoch\")\n",
        "plt.ylabel(\"Accuracy\")\n",
        "plt.show()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nYFamUJLkVY3"
      },
      "source": [
        "## 4.2 Spike Counter"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MDE3ms9ulo-t"
      },
      "source": [
        "Run a forward pass on a batch of data to obtain spike and membrane readings."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dqIjaw1kk4-6"
      },
      "outputs": [],
      "source": [
        "spk_rec, mem_rec = forward_pass(net, num_steps, data)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n4QiJx2HlkMH"
      },
      "source": [
        "Changing `idx` allows you to index into various samples from the simulated minibatch. Use `splt.spike_count` to explore the spiking behaviour of a few different samples!\n",
        "\n",
        "> Note: if you are running the notebook locally on your desktop, please uncomment the line below and modify the path to your ffmpeg.exe\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4HcwxfC6kfy0"
      },
      "outputs": [],
      "source": [
        "from IPython.display import HTML\n",
        "\n",
        "idx = 0\n",
        "\n",
        "fig, ax = plt.subplots(facecolor='w', figsize=(12, 7))\n",
        "labels=['0', '1', '2', '3', '4', '5', '6', '7', '8','9']\n",
        "print(f\"The target label is: {targets[idx]}\")\n",
        "\n",
        "# plt.rcParams['animation.ffmpeg_path'] = 'C:\\\\path\\\\to\\\\your\\\\ffmpeg.exe'\n",
        "\n",
        "#  Plot spike count histogram\n",
        "anim = splt.spike_count(spk_rec[:, idx].detach().cpu(), fig, ax, labels=labels, \n",
        "                        animate=True, interpolate=4)\n",
        "\n",
        "HTML(anim.to_html5_video())\n",
        "# anim.save(\"spike_bar.mp4\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s0dAgWUt2o6E"
      },
      "source": [
        "# Conclusion\n",
        "You should now have a grasp of the basic features of snnTorch and be able to start running your own experiments. [In the next tutorial](https://snntorch.readthedocs.io/en/latest/tutorials/index.html), we will train a network using a neuromorphic dataset.\n",
        "\n",
        "If you like this project, please consider starring ⭐ the repo on GitHub as it is the easiest and best way to support it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Additional Resources\n",
        "* [Check out the snnTorch GitHub project here.](https://github.com/jeshraghian/snntorch)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "include_colab_link": true,
      "name": "tutorial_6_CNN.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 2
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython2",
      "version": "2.7.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
